from typing import Optional, Sequence, Callable, Dict

import numpy as np
from scipy.stats import pearsonr
import torch
import torch.nn as nn
import torch.nn.functional as F

from copy import deepcopy
from tqdm import tqdm

from conflictfree.grad_operator import ConFIGOperator
from conflictfree.momentum_operator import PseudoMomentumOperator
from conflictfree.utils import OrderedSliceSelector, get_gradient_vector

from neural_fields.gk_losses import get_integrals, integral_losses, spectra_losses
from neural_fields.data import CycloneNFDataset, CycloneNFDataLoader
from neural_fields.nf_utils import plotND, plot_diag, sample_field


def _nf_timesteps(data: CycloneNFDataset) -> list:
    """Return the timesteps to iterate over for ``data``.

    When ``data.ndim == 6`` one index is produced per entry along the first
    axis of ``data.grid``; otherwise a single ``None`` (no timestep) is used.
    """
    if data.ndim == 6:
        return list(range(data.grid.shape[0]))
    return [None]


def _weighted_terms(loss_dict: dict, weights: Dict[str, float]) -> list:
    """Return ``weights[k] * loss_dict[f"{k} loss"]`` for every non-zero weight.

    Terms are produced in the iteration order of ``weights``.
    """
    return [w * loss_dict[f"{k} loss"] for k, w in weights.items() if w != 0.0]


def _integral_and_spectral_losses(
    model: nn.Module,
    data: CycloneNFDataset,
    device: torch.device,
    timestep: Optional[int],
    *,
    use_flux_fields: bool,
    use_spectral: bool,
    cheat_integral: bool,
):
    """Compute the integral and spectral losses of ``model`` for one timestep.

    Returns:
        ``(int_losses, spec_losses, (gt_diag, pred_diag), gt_df, gt_phi)``:
        the loss dicts from ``integral_losses`` / ``spectra_losses``, the
        spectral diagnostics, and the ground-truth df and phi fields.
    """
    (
        int_losses,
        (pred_df, gt_df),
        (pred_phi, gt_phi),
        (pred_eflux, gt_eflux),
    ) = integral_losses(
        model,
        data,
        device,
        use_flux_fields=use_flux_fields,
        use_spectral=use_spectral,
        cheat_integral=cheat_integral,
        timestep=timestep,
        return_fields=True,
    )
    spec_losses, (gt_diag, pred_diag) = spectra_losses(
        pred_df=pred_df,
        pred_phi=pred_phi,
        pred_eflux=pred_eflux,
        gt_df=gt_df,
        gt_phi=gt_phi,
        gt_eflux=gt_eflux,
        ds=data.ds,
    )
    return int_losses, spec_losses, (gt_diag, pred_diag), gt_df, gt_phi


def train_nf(
    model: nn.Module,
    n_epochs: int,
    data: CycloneNFDataset,
    loader: CycloneNFDataLoader,
    device: torch.device,
    field_subsamples: Optional[Sequence[float]] = None,
    use_flux_fields: bool = False,
    use_spectral: bool = False,
    cheat_integral: bool = False,
    field_loss: bool = True,
    physical_loss: bool = False,
    optim: Optional[torch.optim.Optimizer] = None,
    sched: Optional[torch.optim.lr_scheduler.LRScheduler] = None,
    aux_optim: Optional[torch.optim.Optimizer] = None,
    aux_sched: Optional[torch.optim.lr_scheduler.LRScheduler] = None,
    integral_loss_weight: Optional[Dict[str, float]] = None,
    physical_loss_weight: Optional[Dict[str, float]] = None,
    core_model_loss: Optional[Callable] = None,
    use_tqdm: bool = True,
    use_print: bool = True,
    conditioned: bool = False,
    use_conflictfree: str = "none",
):
    """Train a neural field with optional field-fitting and physical losses.

    Each epoch runs up to three phases:

    1. Field fitting (``field_loss``): MSE between ``model`` predictions and
       the targets yielded by ``loader``, optimised with ``optim``.
    2. Physical losses (``physical_loss``): for every timestep, the weighted
       integral and spectral losses are back-propagated and stepped with
       ``aux_optim``. With ``use_conflictfree="full"`` the per-loss gradients
       are combined by ``ConFIGOperator``. With ``"pseudo"``, one loss per
       step is selected and fed to ``PseudoMomentumOperator``. Any other value
       sums the weighted losses.
    3. Evaluation: integral/spectral losses averaged over timesteps, plus
       PSNR and Pearson correlations of the ky / q spectra.

    Args:
        model: Neural field, called as ``model(coords)`` or
            ``model(coords, cond=cond)`` when ``conditioned``.
        n_epochs: Number of epochs to run.
        data: Dataset used for the physical losses and for evaluation.
        loader: Iterable of ``(target, coords)`` batches for field fitting.
        device: Device that ``model``, ``data`` and ``loader`` are moved to.
        field_subsamples: Optional per-epoch value assigned to
            ``loader.subsample``. It is indexed by epoch, so it needs at least
            ``n_epochs`` entries.
        use_flux_fields, use_spectral, cheat_integral: Forwarded to
            ``integral_losses``.
        field_loss: Enable the field-fitting phase (requires ``optim``).
        physical_loss: Enable the physical-loss phase (requires ``aux_optim``).
        optim, sched: Optimiser and scheduler for field fitting. ``sched`` is
            stepped once per epoch.
        aux_optim, aux_sched: Optimiser and scheduler for the physical
            losses. Both are stepped once per timestep.
        integral_loss_weight: Weight per integral-loss name ``k``, applied to
            the ``"{k} loss"`` entry. Zero weights are skipped; ``None`` means
            no terms.
        physical_loss_weight: Same as ``integral_loss_weight``, for the
            spectral losses.
        core_model_loss: Optional extra loss ``core_model_loss(model)`` added
            to every field-fitting batch.
        use_tqdm: Wrap ``loader`` in a tqdm progress bar.
        use_print: Print the epoch metrics.
        conditioned: Batches carry ``(coords, cond)`` in place of ``coords``.
        use_conflictfree: ``"none"``, ``"full"`` or ``"pseudo"`` (see above).

    Returns:
        ``(model, best_model, losses)``:
        - the model after the last epoch;
        - a deep copy of the model at the best validation PSNR (phi PSNR when
          ``physical_loss`` is set, df PSNR otherwise), or ``None`` if no
          epoch beat ``-inf``;
        - the metric dict of the last epoch (``{}`` if ``n_epochs == 0``).
    """
    # None defaults avoid the shared-mutable-default pitfall.
    integral_loss_weight = {} if integral_loss_weight is None else integral_loss_weight
    physical_loss_weight = {} if physical_loss_weight is None else physical_loss_weight

    model.train()

    torch.set_float32_matmul_precision("high")
    i = 0
    best_loss, best_model = -torch.inf, None
    data.to(device)
    loader.to(device)
    model.to(device)

    # One conflict-free "task" per non-zero weighted loss term.
    n_tasks = sum(1 for w in integral_loss_weight.values() if w != 0.0)
    n_tasks += sum(1 for w in physical_loss_weight.values() if w != 0.0)

    if use_conflictfree == "full":
        operator = ConFIGOperator()
    if use_conflictfree == "pseudo":
        operator = PseudoMomentumOperator(n_tasks)
        loss_selector = OrderedSliceSelector()

    # Bound before the loop so that n_epochs == 0 returns {} instead of raising.
    losses = {}
    for e in range(n_epochs):
        losses = {}

        if field_loss:
            if field_subsamples is not None:
                loader.subsample = field_subsamples[e]

            ll = []
            ploader = tqdm(loader, desc=f"Loss: {0.0:.6f}") if use_tqdm else loader
            for f, coords in ploader:
                kwcond = {}
                if conditioned:
                    coords, cond = coords
                    kwcond = {"cond": cond}
                pred_f = model(coords, **kwcond)
                # neural field loss
                loss = F.mse_loss(pred_f, f)
                if core_model_loss:
                    loss += core_model_loss(model)

                optim.zero_grad()
                loss.backward()
                optim.step()
                ll.append(loss.item())
                i += 1
                # Refresh the progress-bar text about every 50 batches.
                if i > 50 and use_tqdm:
                    ploader.set_description(f"Loss: {sum(ll) / len(ll):.6f}")
                    i = 0

            # An empty loader reports NaN instead of dividing by zero.
            losses["train/loss"] = sum(ll) / len(ll) if ll else float("nan")

            if sched is not None:
                sched.step()

        if physical_loss:
            for t in _nf_timesteps(data):
                int_losses, spec_losses, _, _, _ = _integral_and_spectral_losses(
                    model,
                    data,
                    device,
                    t,
                    use_flux_fields=use_flux_fields,
                    use_spectral=use_spectral,
                    cheat_integral=cheat_integral,
                )
                aux_losses = _weighted_terms(
                    int_losses, integral_loss_weight
                ) + _weighted_terms(spec_losses, physical_loss_weight)

                # Clear gradients in every mode. Otherwise the plain-sum branch
                # would accumulate leftovers from the last field-fitting batch
                # and from earlier timesteps into this aux_optim step.
                aux_optim.zero_grad(set_to_none=True)
                if use_conflictfree == "pseudo":
                    idx, loss_i = loss_selector.select(1, aux_losses)
                    loss_i.backward()
                    operator.update_gradient(model, idx, get_gradient_vector(model))
                elif use_conflictfree == "full":
                    grads = []
                    for loss_i in aux_losses:
                        # Isolate each term's gradient; the shared graph is kept
                        # alive for the remaining terms.
                        aux_optim.zero_grad(set_to_none=True)
                        loss_i.backward(retain_graph=True)
                        grads.append(get_gradient_vector(model, none_grad_mode="zero"))
                    operator.update_gradient(model, grads)
                else:
                    sum(aux_losses).backward()

                aux_optim.step()

                if aux_sched is not None:
                    aux_sched.step()

        # eval
        with torch.no_grad():
            eval_losses = []
            for t in _nf_timesteps(data):
                t_eval, spec_losses, (gt_diag, pred_diag), gt_df, gt_phi = (
                    _integral_and_spectral_losses(
                        model,
                        data,
                        device,
                        t,
                        use_flux_fields=use_flux_fields,
                        use_spectral=use_spectral,
                        cheat_integral=cheat_integral,
                    )
                )
                eval_losses.append(t_eval | spec_losses)
            # Mean of each metric over timesteps (keys taken from the first).
            eval_losses = {
                k: sum(v[k] for v in eval_losses) / len(eval_losses)
                for k in eval_losses[0]
            }
            losses.update({f"val/{k}": v for k, v in eval_losses.items()})
            # gt_df, gt_phi, gt_diag and pred_diag below are from the LAST
            # timestep only, while the losses are timestep averages.
            # NOTE(review): these divide by the squared loss. Conventional
            # PSNR is 10 * log10(MAX**2 / MSE); confirm this form is intended.
            losses["val/df psnr"] = 10 * torch.log10(
                gt_df.max() ** 2 / losses["val/df loss"] ** 2
            )
            losses["val/phi psnr"] = 10 * torch.log10(
                gt_phi.max() ** 2 / losses["val/phi mse"] ** 2
            )
            losses["val/ky pc"] = pearsonr(pred_diag["kyspec"].cpu(), gt_diag["kyspec"].cpu())[0]
            losses["val/q pc"] = pearsonr(pred_diag["qspec"].cpu(), gt_diag["qspec"].cpu())[0]

            # Model selection uses phi PSNR when physical losses are trained and
            # df PSNR otherwise, including when neither phase is enabled.
            curr_loss = (
                losses["val/phi psnr"] if physical_loss else losses["val/df psnr"]
            )

            if curr_loss > best_loss:
                best_loss = curr_loss
                best_model = deepcopy(model)

        if use_print:
            str_losses = ", ".join([f"{k}: {float(v):.6f}" for k, v in losses.items()])
            print(f"[{e}] {str_losses}")

    return model, best_model, losses


@torch.no_grad()
def eval_diagnose(
    data: CycloneNFDataset,
    device: torch.device,
    model: Optional[nn.Module] = None,
    pred_df: Optional[torch.Tensor] = None,
    T: Optional[int] = None,
    use_spectral: bool = False,
    cheat_integral: bool = False,
    metrics_only: bool = False,
):
    """Print diagnostics comparing a predicted df against the ground truth.

    The prediction is either sampled from ``model``, which takes precedence,
    or passed directly as ``pred_df``. Potentials and flux fields are derived
    for both prediction and ground truth with ``get_integrals``. Spectral
    diagnostics come from ``spectra_losses``. A summary of errors, PSNRs,
    fluxes and spectral losses is printed.

    Args:
        data: Dataset providing the ground truth ``full_df`` and ``ds``.
        device: Device used for sampling and for ``get_integrals``.
        model: Optional neural field. If given, ``pred_df`` is sampled from
            it with ``sample_field`` and any passed ``pred_df`` is ignored.
        pred_df: Precomputed prediction, used when ``model`` is ``None``.
        T: Optional timestep. Selects ``full_df[:, T]`` as ground truth and
            is forwarded to ``sample_field``.
        use_spectral: Forwarded to ``get_integrals`` as ``spectral_df``.
        cheat_integral: Passed to ``get_integrals`` for the prediction as
            ``phi_integral=not cheat_integral``.
        metrics_only: Only print metrics; skip building figures.

    Returns:
        ``(fig_df, fig_eflux, fig_potens, fig_diag)``, or ``None`` when
        ``metrics_only`` is set.

    Raises:
        ValueError: If neither ``model`` nor ``pred_df`` is provided.
    """
    if model is None and pred_df is None:
        raise ValueError("eval_diagnose requires either `model` or `pred_df`.")
    if model is not None:
        model.to(device)
        pred_df = sample_field(model, data, device, timestep=T)
    pred_df = pred_df.to(device)
    # Clone so nothing downstream can modify the dataset's tensor in place.
    gt_df = data.full_df.clone()
    if T is not None:
        gt_df = gt_df[:, T]
    gt_phi, (_, gt_eflux, _) = get_integrals(
        gt_df.to(device), data, flux_fields=True, spectral_df=use_spectral
    )
    pred_phi, (pred_pflux, pred_eflux, _) = get_integrals(
        pred_df,
        data,
        flux_fields=True,
        spectral_df=use_spectral,
        phi_integral=not cheat_integral,
    )

    # Make each host copy once and reuse it. Every .cpu() on a device tensor
    # is a separate transfer.
    pred_df_cpu, gt_df_cpu = pred_df.cpu(), gt_df.cpu()
    pred_phi_cpu, gt_phi_cpu = pred_phi.cpu(), gt_phi.cpu()
    pred_eflux_cpu, gt_eflux_cpu = pred_eflux.cpu(), gt_eflux.cpu()

    # diagnostics
    spec_losses, (gt_diag, pred_diag) = spectra_losses(
        pred_df_cpu,
        pred_phi_cpu,
        pred_eflux_cpu,
        gt_df_cpu,
        gt_phi_cpu,
        gt_eflux_cpu,
        data.ds,
    )

    mse = ((pred_df_cpu - gt_df_cpu) ** 2).mean()
    # NOTE(review): both PSNRs divide by mse**2. Conventional PSNR is
    # 10 * log10(MAX**2 / MSE); confirm this form is intended.
    psnr = 10 * torch.log10(gt_df_cpu.max() ** 2 / mse**2)
    phi_mse = ((pred_phi - gt_phi) ** 2).mean()
    phi_psnr = 10 * torch.log10(gt_phi.max() ** 2 / phi_mse**2)
    print(
        f"df nmse: {mse / (gt_df_cpu ** 2).mean():.2f}, "
        f"df psnr: {psnr.item():.2f}\n"
        f"pflux: {pred_pflux.sum():.2f}, "
        f"eflux: {pred_eflux.sum():.2f}, gt eflux {gt_eflux.sum():.2f}\n"
        f"phi nmse: {phi_mse / (gt_phi ** 2).mean():.2f}, "
        f"phi psnr: {phi_psnr:.2f}\n"
        f"kyspec L1: {spec_losses['kyspec loss']:.2f}, "
        f"kyspec mono: {spec_losses['kyspec monotonicity loss']:.2f}\n"
        f"qspec L1: {spec_losses['qspec loss']:.2f}, "
        f"qspec mono: {spec_losses['qspec monotonicity loss']:.2f}\n"
    )
    # plots
    if not metrics_only:
        fig_df = plotND(pred_df_cpu.numpy(), gt_df_cpu.numpy())
        fig_eflux = plotND(pred_eflux_cpu.numpy(), gt_eflux_cpu.numpy())
        fig_potens = plotND(
            pred_phi_cpu.numpy(),
            gt_phi_cpu.numpy(),
            n=3,
            cmap="plasma",
            aspect=2,
            aggregate="slice",
        )
        fig_diag = plot_diag([gt_diag], [pred_diag])
        return fig_df, fig_eflux, fig_potens, fig_diag
